{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from itertools import zip_longest\n",
    "\n",
    "from docx import Document\n",
    "\n",
    "\n",
    "def _merge_continuation_rows(rows):\n",
    "    \"\"\"Fold continuation fragments into the row they belong to.\n",
    "\n",
    "    A row whose first cell is empty is treated as the continuation of the\n",
    "    previous row (a table row split across pages). Any number of consecutive\n",
    "    fragments is folded in, cell by cell.\n",
    "    \"\"\"\n",
    "    merged = []\n",
    "    for cells in rows:\n",
    "        if merged and cells[0] == '':\n",
    "            # zip_longest pads the shorter side so no trailing cell is dropped.\n",
    "            merged[-1] = [a + b for a, b in zip_longest(merged[-1], cells, fillvalue='')]\n",
    "        else:\n",
    "            merged.append(cells)\n",
    "    return merged\n",
    "\n",
    "\n",
    "def _attach_headers(rows):\n",
    "    \"\"\"Prefix every row with the most recent header row.\n",
    "\n",
    "    A header row is one whose first cell is '序号'. Each output string is the\n",
    "    header cells and the row cells, both joined with '|', separated by a\n",
    "    newline. The header row itself is emitted too (paired with itself), and\n",
    "    rows seen before any header get an empty header line.\n",
    "    \"\"\"\n",
    "    chunks = []\n",
    "    header = []\n",
    "    for cells in rows:\n",
    "        if cells[0] == '序号':\n",
    "            header = cells\n",
    "        chunks.append('|'.join(header) + '\\n' + '|'.join(cells))\n",
    "    return chunks\n",
    "\n",
    "\n",
    "def read_table_from_word(file_path):\n",
    "    \"\"\"Read every table of a Word file into 'header + row' text chunks.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path of the .docx file, passed to docx.Document.\n",
    "\n",
    "    Returns:\n",
    "        A list of strings, one per logical table row, in document order.\n",
    "    \"\"\"\n",
    "    document = Document(file_path)\n",
    "    rows = []\n",
    "    for table in document.tables:\n",
    "        for row in table.rows:\n",
    "            # Line breaks inside a cell are removed so each cell stays on one line.\n",
    "            cells = [cell.text.replace('\\n', '') for cell in row.cells]\n",
    "            if cells:  # a row without cells has no first cell to inspect\n",
    "                rows.append(cells)\n",
    "    return _attach_headers(_merge_continuation_rows(rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_core.documents import Document as doc\n",
    "import os\n",
    "file_path = '/home/user/wyf/rag_table.docx'\n",
    "\n",
    "def get_docs(file_path):\n",
    "    \"\"\"Wrap each table chunk of a Word file in a LangChain document.\n",
    "\n",
    "    Every document carries the file's base name under the 'source' metadata key.\n",
    "    \"\"\"\n",
    "    source_name = os.path.basename(file_path)\n",
    "    documents = []\n",
    "    for chunk in read_table_from_word(file_path):\n",
    "        documents.append(doc(page_content=chunk, metadata={'source': source_name}))\n",
    "    return documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Build one LangChain document per chunk returned by read_table_from_word.\n",
    "splits = get_docs(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Peek at one chunk; index 0 is presumably the header row paired with itself - confirm.\n",
    "print(splits[1].page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_huggingface import HuggingFaceEmbeddings\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Embedding model loaded from a local directory onto CUDA device 1.\n",
    "embeddings = HuggingFaceEmbeddings(\n",
    "    model_name='/home/user/wyf/fastchat/bge-large-zh-v1.5',\n",
    "    model_kwargs={'device': 'cuda:1'},\n",
    ")\n",
    "\n",
    "# Chat model reached through a custom base_url; the api_key is presumably a\n",
    "# placeholder for a self-hosted endpoint - confirm before sharing.\n",
    "llm = ChatOpenAI(\n",
    "    temperature=0,\n",
    "    model='/home/user/Downloads/Qwen2-7B-Instruct/',\n",
    "    base_url='http://10.250.2.23:8600/v1',\n",
    "    api_key='ww',\n",
    ")\n",
    "\n",
    "# Embed every chunk in `splits` and index the vectors in FAISS.\n",
    "vectorstore = FAISS.from_documents(splits, embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Retriever over the FAISS index, configured with k=3 results per query.\n",
    "retriever = vectorstore.as_retriever(search_kwargs={'k':3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "query = '斑小将美白隔离防晒乳的检验结果'\n",
    "# Retriever.invoke replaces the deprecated get_relevant_documents.\n",
    "recall_text = retriever.invoke(query)\n",
    "# Concatenate the retrieved chunks, each followed by a blank line.\n",
    "text = ''.join(chunk.page_content + '\\n\\n' for chunk in recall_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Show the retrieved context that is later passed to the LLM.\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Prompt template with two positional slots: the retrieved knowledge, then the user question.\n",
    "# The instructions tell the model to answer only from the given knowledge and to\n",
    "# reply that it does not know when the knowledge does not contain the answer.\n",
    "prompt = '''任务目标：根据给定的知识，回答用户提出的问题。\n",
    "任务要求：\n",
    "1、不得脱离给定的知识进行回答。\n",
    "2、如果给定的知识不包含问题的答案，请回答我不知道。\n",
    "\n",
    "知识：\n",
    "{}\n",
    "\n",
    "用户问题：\n",
    "{}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "query = '斑小将美白隔离防晒乳的检验结果'\n",
    "# Fill the template with the retrieved context and the question, then show the reply.\n",
    "filled_prompt = prompt.format(text, query)\n",
    "print(llm.invoke(filled_prompt).content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import mammoth\n",
    "\n",
    "# Convert the Word file to HTML; only the file handle needs the with-block.\n",
    "with open(\"/home/user/wyf/rag_test.docx\", \"rb\") as docx_file:\n",
    "    result = mammoth.convert_to_html(docx_file)\n",
    "\n",
    "# The converted markup is carried on the result's `value` attribute.\n",
    "html = result.value\n",
    "print(html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Dump the converted HTML to tmp.html in the working directory (UTF-8).\n",
    "# No later cell reads this file back; presumably it is for manual inspection.\n",
    "with open('tmp.html', 'w', encoding='utf-8') as f:\n",
    "    f.write(html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_text_splitters import HTMLSectionSplitter\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "# Heading tags to split on, each paired with a label.\n",
    "headers_to_split_on = [(\"h1\", \"Header 1\"), (\"h2\", \"Header 2\"), (\"h3\", \"Header 3\")]\n",
    "html_splitter = HTMLSectionSplitter(headers_to_split_on)\n",
    "html_header_splits = html_splitter.split_text(html)\n",
    "\n",
    "# Print every section followed by a separator line.\n",
    "for section in html_header_splits:\n",
    "    print(section, \"=======================================================\", sep=\"\\n\")\n",
    "\n",
    "# Optional second-stage splitting, currently disabled:\n",
    "# chunk_size = 500\n",
    "# chunk_overlap = 50\n",
    "# text_splitter = RecursiveCharacterTextSplitter(\n",
    "#     chunk_size=chunk_size, chunk_overlap=chunk_overlap\n",
    "# )\n",
    "\n",
    "# # Split\n",
    "# splits = text_splitter.split_documents(html_header_splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Index the HTML sections in a new FAISS store.\n",
    "# NOTE: this rebinds `vectorstore`, replacing the index built from the table chunks.\n",
    "vectorstore = FAISS.from_documents(html_header_splits, embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Rebind `retriever` to the current `vectorstore`, configured with k=1 result per query.\n",
    "retriever = vectorstore.as_retriever(search_kwargs={'k':1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "query = '国家药监局什么时候发布了《关于更新化妆品禁用原料目录的公告》'\n",
    "# Retriever.invoke replaces the deprecated get_relevant_documents.\n",
    "recall_text = retriever.invoke(query)\n",
    "# Concatenate the retrieved sections, each followed by a blank line.\n",
    "text = ''.join(section.page_content + '\\n\\n' for section in recall_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "query = '国家药监局什么时候发布了《关于更新化妆品禁用原料目录的公告》'\n",
    "# Fill the template with the retrieved section and the question, then show the reply.\n",
    "filled_prompt = prompt.format(text, query)\n",
    "print(llm.invoke(filled_prompt).content)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
